{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4JlLTP1Y-WHg"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "if-ujOZN-Par"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Uq9kCbELjzgJ"
      },
      "source": [
        "# Efficient serving\n",
        "\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/recommenders/examples/efficient_serving\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/recommenders/blob/main/docs/examples/efficient_serving.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/recommenders/blob/main/docs/examples/efficient_serving.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/recommenders/docs/examples/efficient_serving.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UlFcUNXT7hSF"
      },
      "source": [
        "[Retrieval models](https://www.tensorflow.org/recommenders/examples/basic_retrieval) are often built to surface a handful of top candidates out of millions or even hundreds of millions of candidates. To be able to react to the user's context and behaviour, they need to be able to do this on the fly, in a matter of milliseconds.\n",
        "\n",
        "Approximate nearest neighbour search (ANN) is the technology that makes this possible. In this tutorial, we'll show how to use ScaNN, a state-of-the-art nearest neighbour retrieval package, to seamlessly scale TFRS retrieval to millions of items."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q_s_2UgUWA9u"
      },
      "source": [
        "## What is ScaNN?"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GSvmiDQPsGmb"
      },
      "source": [
        "ScaNN is a library from Google Research that performs dense vector similarity search at large scale. Given a database of candidate embeddings, ScaNN indexes these embeddings in a manner that allows them to be rapidly searched at inference time. ScaNN uses state-of-the-art vector compression techniques and carefully implemented algorithms to achieve a strong speed-accuracy tradeoff. It can greatly outperform brute-force search while sacrificing little in terms of accuracy."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bTpnORU7WEPD"
      },
      "source": [
        "## Building a ScaNN-powered model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zXEZ3lZnWIVh"
      },
      "source": [
        "To try out ScaNN in TFRS, we'll build a simple MovieLens retrieval model, just as we did in the [basic retrieval](https://www.tensorflow.org/recommenders/examples/basic_retrieval) tutorial. If you have followed that tutorial, this section will be familiar and can safely be skipped.\n",
        "\n",
        "To start, install TFRS and TensorFlow Datasets:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mD2hiRviCxFE"
      },
      "outputs": [],
      "source": [
        "!pip install -q tensorflow-recommenders\n",
        "!pip install -q --upgrade tensorflow-datasets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oEbc-66nDJzc"
      },
      "source": [
        "We also need to install `scann`: it's an optional dependency of TFRS, and so needs to be installed separately."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "daEivxsJDO0Y"
      },
      "outputs": [],
      "source": [
        "!pip install -q scann"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bDe054pgDQdp"
      },
      "source": [
        "Set up all the necessary imports."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6ekaJkcuHsiY"
      },
      "outputs": [],
      "source": [
        "from typing import Dict, Text\n",
        "\n",
        "import os\n",
        "import pprint\n",
        "import tempfile\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WdTPCz136mvc"
      },
      "outputs": [],
      "source": [
        "import tensorflow_recommenders as tfrs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DfmRuUgJWlEQ"
      },
      "source": [
        "And load the data:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k-VF30hJn5-3"
      },
      "outputs": [],
      "source": [
        "# Load the MovieLens 100K data.\n",
        "ratings = tfds.load(\n",
        "    \"movielens/100k-ratings\",\n",
        "    split=\"train\"\n",
        ")\n",
        "\n",
        "# Get the ratings data.\n",
        "ratings = (ratings\n",
        "           # Retain only the fields we need.\n",
        "           .map(lambda x: {\"user_id\": x[\"user_id\"], \"movie_title\": x[\"movie_title\"]})\n",
        "           # Cache for efficiency.\n",
        "           .cache(tempfile.NamedTemporaryFile().name)\n",
        ")\n",
        "\n",
        "# Get the movies data.\n",
        "movies = tfds.load(\"movielens/100k-movies\", split=\"train\")\n",
        "movies = (movies\n",
        "          # Retain only the fields we need.\n",
        "          .map(lambda x: x[\"movie_title\"])\n",
        "          # Cache for efficiency.\n",
        "          .cache(tempfile.NamedTemporaryFile().name))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SiVuNZ-lWv0R"
      },
      "source": [
        "Before we can build a model, we need to set up the user and movie vocabularies:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jw-iQKBBajnz"
      },
      "outputs": [],
      "source": [
        "user_ids = ratings.map(lambda x: x[\"user_id\"])\n",
        "\n",
        "unique_movie_titles = np.unique(np.concatenate(list(movies.batch(1000))))\n",
        "unique_user_ids = np.unique(np.concatenate(list(user_ids.batch(1000))))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yRbZCvWHWzPU"
      },
      "source": [
        "We'll also set up the training and test sets:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FqV8p7N8CrEg"
      },
      "outputs": [],
      "source": [
        "tf.random.set_seed(42)\n",
        "shuffled = ratings.shuffle(100_000, seed=42, reshuffle_each_iteration=False)\n",
        "\n",
        "train = shuffled.take(80_000)\n",
        "test = shuffled.skip(80_000).take(20_000)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ok3-kzr1bI7U"
      },
      "source": [
        "### Model definition\n",
        "\n",
        "Just as in the [basic retrieval](https://www.tensorflow.org/recommenders/examples/basic_retrieval) tutorial, we build a simple two-tower model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yX_j4pEVbKIS"
      },
      "outputs": [],
      "source": [
        "class MovielensModel(tfrs.Model):\n",
        "\n",
        "  def __init__(self):\n",
        "    super().__init__()\n",
        "\n",
        "    embedding_dimension = 32\n",
        "\n",
        "    # Set up a model for representing movies.\n",
        "    self.movie_model = tf.keras.Sequential([\n",
        "      tf.keras.layers.StringLookup(\n",
        "        vocabulary=unique_movie_titles, mask_token=None),\n",
        "      # We add an additional embedding to account for unknown tokens.\n",
        "      tf.keras.layers.Embedding(len(unique_movie_titles) + 1, embedding_dimension)\n",
        "    ])\n",
        "\n",
        "    # Set up a model for representing users.\n",
        "    self.user_model = tf.keras.Sequential([\n",
        "      tf.keras.layers.StringLookup(\n",
        "        vocabulary=unique_user_ids, mask_token=None),\n",
        "      # We add an additional embedding to account for unknown tokens.\n",
        "      tf.keras.layers.Embedding(len(unique_user_ids) + 1, embedding_dimension)\n",
        "    ])\n",
        "\n",
        "    # Set up a task to optimize the model and compute metrics.\n",
        "    self.task = tfrs.tasks.Retrieval(\n",
        "      metrics=tfrs.metrics.FactorizedTopK(\n",
        "        candidates=(\n",
        "            movies\n",
        "            .batch(128)\n",
        "            .cache()\n",
        "            .map(lambda title: (title, self.movie_model(title)))\n",
        "        )\n",
        "      )\n",
        "    )\n",
        "\n",
        "  def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -\u003e tf.Tensor:\n",
        "    # We pick out the user features and pass them into the user model.\n",
        "    user_embeddings = self.user_model(features[\"user_id\"])\n",
        "    # And pick out the movie features and pass them into the movie model,\n",
        "    # getting embeddings back.\n",
        "    positive_movie_embeddings = self.movie_model(features[\"movie_title\"])\n",
        "\n",
        "    # The task computes the loss and the metrics.\n",
        "    return self.task(\n",
        "        user_embeddings,\n",
        "        positive_movie_embeddings,\n",
        "        candidate_ids=features[\"movie_title\"],\n",
        "        compute_metrics=not training\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JtO3lKR_XKkw"
      },
      "source": [
        "### Fitting and evaluation\n",
        "\n",
        "A TFRS model is just a Keras model. We can compile it:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uOGTdwAAbuB6"
      },
      "outputs": [],
      "source": [
        "model = MovielensModel()\n",
        "model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4rGLyo-XXPmX"
      },
      "source": [
        "Fit it:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uf_E4dIMcGnk"
      },
      "outputs": [],
      "source": [
        "model.fit(train.batch(8192), epochs=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7xymbWgVXSrT"
      },
      "source": [
        "And evaluate it:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EMlIj741cIT8"
      },
      "outputs": [],
      "source": [
        "model.evaluate(test.batch(8192), return_dict=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3RbHiBWqsFmf"
      },
      "source": [
        "## Approximate prediction\n",
        "\n",
        "The most straightforward way of retrieving top candidates in response to a query is to do it via brute force: compute user-movie scores for all possible movies, sort them, and pick a couple of top recommendations.\n",
        "\n",
        "In TFRS, this is accomplished via the `BruteForce` layer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x_L2yAPjpHsk"
      },
      "outputs": [],
      "source": [
        "brute_force = tfrs.layers.factorized_top_k.BruteForce(model.user_model)\n",
        "brute_force.index_from_dataset(\n",
        "    movies.batch(128).map(lambda title: (title, model.movie_model(title)))\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CzoNR28vXw7o"
      },
      "source": [
        "Once created and populated with candidates (via the `index_from_dataset` method), we can call it to get predictions out:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SBo1Nu0Grife"
      },
      "outputs": [],
      "source": [
        "# Get predictions for user 42.\n",
        "_, titles = brute_force(np.array([\"42\"]), k=3)\n",
        "\n",
        "print(f\"Top recommendations: {titles[0]}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AzNECPifr6i6"
      },
      "source": [
        "On a small dataset of fewer than 2,000 movies (MovieLens 100K has 1,682), this is very fast:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w57iyu7Ir87Q"
      },
      "outputs": [],
      "source": [
        "%timeit _, titles = brute_force(np.array([\"42\"]), k=3)"
      ]
    },
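    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bfTopKSketch"
      },
      "source": [
        "Conceptually, brute-force retrieval is a dot product between the query embedding and every candidate embedding, followed by a top-k selection. The NumPy sketch below illustrates this on random data. `brute_force_top_k` is a hypothetical helper for illustration, not the implementation of the TFRS `BruteForce` layer.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def brute_force_top_k(query, candidates, k):\n",
        "  # Score every candidate with a dot product, as a two-tower model does.\n",
        "  scores = candidates @ query\n",
        "  # argpartition puts the indices of the k highest scores first, without a full sort.\n",
        "  top = np.argpartition(-scores, k - 1)[:k]\n",
        "  # Order just those k indices by descending score.\n",
        "  return top[np.argsort(-scores[top])]\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "candidates = rng.normal(size=(1000, 32)).astype(np.float32)\n",
        "query = rng.normal(size=32).astype(np.float32)\n",
        "print(brute_force_top_k(query, candidates, k=3))\n",
        "```\n",
        "\n",
        "`np.argpartition` avoids sorting all n scores. The scoring step still touches every candidate, however, so the cost grows linearly with the size of the candidate set."
      ]
    },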
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "u2AjJsdrsClR"
      },
      "source": [
        "But what happens if we have more candidates: millions instead of thousands?\n",
        "\n",
        "We can simulate this by indexing all of our movies multiple times:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AapJk84csTqV"
      },
      "outputs": [],
      "source": [
        "# Construct a dataset of movies that's roughly 1,000 times larger. We\n",
        "# do this by adding about 1.7 million dummy (empty-string) movie titles.\n",
        "lots_of_movies = tf.data.Dataset.concatenate(\n",
        "    movies.batch(4096),\n",
        "    movies.batch(4096).repeat(1_000).map(lambda x: tf.zeros_like(x))\n",
        ")\n",
        "\n",
        "# We also add lots of dummy embeddings by randomly perturbing\n",
        "# the estimated embeddings for real movies.\n",
        "lots_of_movies_embeddings = tf.data.Dataset.concatenate(\n",
        "    movies.batch(4096).map(model.movie_model),\n",
        "    movies.batch(4096).repeat(1_000)\n",
        "      .map(lambda x: model.movie_model(x))\n",
        "      .map(lambda x: x * tf.random.uniform(tf.shape(x)))\n",
        ")"
      ]
    },
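    {
      "cell_type": "markdown",
      "metadata": {
        "id": "candidateSetSize"
      },
      "source": [
        "The arithmetic below gives the size of the enlarged candidate set. It uses MovieLens 100K's 1,682 movies, the original copy plus 1,000 repeats, and the 32-dimensional float32 embeddings of `MovielensModel`. This is a back-of-the-envelope estimate: it ignores the memory taken by the title strings and any framework overhead.\n",
        "\n",
        "```python\n",
        "num_movies = 1682          # movies in MovieLens 100K\n",
        "num_copies = 1 + 1000      # the original movies plus 1,000 dummy repeats\n",
        "embedding_dimension = 32   # as in MovielensModel\n",
        "bytes_per_float = 4        # float32\n",
        "\n",
        "num_candidates = num_movies * num_copies\n",
        "embedding_bytes = num_candidates * embedding_dimension * bytes_per_float\n",
        "\n",
        "print(f'Candidates: {num_candidates:,}')\n",
        "print(f'Candidate embeddings: {embedding_bytes / 1e6:.0f} MB')\n",
        "```\n",
        "\n",
        "Every brute-force query has to read all of these embeddings."
      ]
    },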
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "viCLP9qSYBQh"
      },
      "source": [
        "We can build a `BruteForce` index on this larger dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mfY62oQbYA3Z"
      },
      "outputs": [],
      "source": [
        "brute_force_lots = tfrs.layers.factorized_top_k.BruteForce()\n",
        "brute_force_lots.index_from_dataset(\n",
        "    tf.data.Dataset.zip((lots_of_movies, lots_of_movies_embeddings))\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OrkMt8O_xm-s"
      },
      "source": [
        "The recommendations are still the same:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I9fIYUeYxjki"
      },
      "outputs": [],
      "source": [
        "_, titles = brute_force_lots(model.user_model(np.array([\"42\"])), k=3)\n",
        "\n",
        "print(f\"Top recommendations: {titles[0]}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wwF25ZzdseX8"
      },
      "source": [
        "But they take much longer. With a candidate set of about 1.7 million movies, brute-force prediction becomes quite slow:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oetK_wNxsdw0"
      },
      "outputs": [],
      "source": [
        "%timeit _, titles = brute_force_lots(model.user_model(np.array([\"42\"])), k=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mKF9yEeotbXQ"
      },
      "source": [
        "As the number of candidates grows, the time needed grows linearly. At this rate, 10 million candidates would take roughly 250 milliseconds per query, which is clearly too slow for a live service.\n",
        "\n",
        "This is where approximate mechanisms come in.\n",
        "\n",
        "Using ScaNN in TFRS is accomplished via the `tfrs.layers.factorized_top_k.ScaNN` layer. It follows the same interface as the other top-k layers:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SLgPmA90sbDL"
      },
      "outputs": [],
      "source": [
        "scann = tfrs.layers.factorized_top_k.ScaNN(\n",
        "    num_reordering_candidates=500,\n",
        "    num_leaves_to_search=30\n",
        ")\n",
        "scann.index_from_dataset(\n",
        "    tf.data.Dataset.zip((lots_of_movies, lots_of_movies_embeddings))\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qRI-qv7S2h97"
      },
      "source": [
        "The recommendations are (approximately!) the same:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HCkRn1VnxuXn"
      },
      "outputs": [],
      "source": [
        "_, titles = scann(model.user_model(np.array([\"42\"])), k=3)\n",
        "\n",
        "print(f\"Top recommendations: {titles[0]}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iW1oBtcC2mb1"
      },
      "source": [
        "But they are much, much faster to compute:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ooJsLhpWstlf"
      },
      "outputs": [],
      "source": [
        "%timeit _, titles = scann(model.user_model(np.array([\"42\"])), k=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zOYk0zi12q-0"
      },
      "source": [
        "In this case, we can retrieve the top 3 movies out of a set of ~1.7 million in around 2 milliseconds, 15 times faster than brute-force retrieval. Exact timings depend on the hardware. The advantage of approximate methods grows even larger for larger datasets."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tE7eL7ZzDKtl"
      },
      "source": [
        "## Evaluating the approximation\n",
        "\n",
        "When using approximate top K retrieval mechanisms (such as ScaNN), speed of retrieval often comes at the expense of accuracy. To understand this trade-off, it's important to measure the model's evaluation metrics when using ScaNN, and to compare them with the baseline.\n",
        "\n",
        "Fortunately, TFRS makes this easy. We simply override the metrics on the retrieval task with metrics using ScaNN, re-compile the model, and run evaluation.\n",
        "\n",
        "To make the comparison, let's first run baseline results. We still need to override our metrics to make sure they are using the enlarged candidate set rather than the original set of movies:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZZtJRQqBep_5"
      },
      "outputs": [],
      "source": [
        "# Override the existing streaming candidate source.\n",
        "model.task.factorized_metrics = tfrs.metrics.FactorizedTopK(\n",
        "    candidates=tf.data.Dataset.zip((lots_of_movies, lots_of_movies_embeddings))\n",
        ")\n",
        "# Need to recompile the model for the changes to take effect.\n",
        "model.compile()\n",
        "\n",
        "%time baseline_result = model.evaluate(test.batch(8192), return_dict=True, verbose=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HHFzcS5cQtB_"
      },
      "source": [
        "We can do the same using ScaNN:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5T-YxOqoKMje"
      },
      "outputs": [],
      "source": [
        "model.task.factorized_metrics = tfrs.metrics.FactorizedTopK(\n",
        "    candidates=scann\n",
        ")\n",
        "model.compile()\n",
        "\n",
        "# ScaNN evaluation is more memory efficient, so a much bigger batch size\n",
        "# could be used here. We keep 8192 to match the baseline run.\n",
        "%time scann_result = model.evaluate(test.batch(8192), return_dict=True, verbose=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y0gnjcUUZ6-v"
      },
      "source": [
        "ScaNN-based evaluation is much, much quicker. This advantage will grow even larger for bigger datasets, so for large datasets it may be prudent to always run ScaNN-based evaluation to improve model development velocity.\n",
        "\n",
        "But what about the results? Fortunately, in this case the results are almost the same:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gcXUbZx3Fq4f"
      },
      "outputs": [],
      "source": [
        "print(f\"Brute force top-100 accuracy: {baseline_result['factorized_top_k/top_100_categorical_accuracy']:.2f}\")\n",
        "print(f\"ScaNN top-100 accuracy:       {scann_result['factorized_top_k/top_100_categorical_accuracy']:.2f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d2UJR-5nZ6YT"
      },
      "source": [
        "This suggests that on this artificial dataset, there is little loss from the approximation. In general, all approximate methods exhibit speed-accuracy tradeoffs. To understand this in more depth, you can check out Erik Bernhardsson's [ANN benchmarks](https://github.com/erikbern/ann-benchmarks)."
      ]
    },
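    {
      "cell_type": "markdown",
      "metadata": {
        "id": "topKAccuracySketch"
      },
      "source": [
        "The top-100 accuracy above is reported by `tfrs.metrics.FactorizedTopK`. Roughly, it is the fraction of test queries whose actual (positive) movie ranks among the top 100 candidates for that query. The NumPy sketch below illustrates that definition on random data, using exact brute-force scoring. `top_k_accuracy` is a hypothetical helper, and the TFRS metric may differ in details such as tie handling and how it identifies the positive candidate.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def top_k_accuracy(query_embeddings, candidate_embeddings, positive_ids, k):\n",
        "  # Score every candidate for every query: shape (num_queries, num_candidates).\n",
        "  scores = query_embeddings @ candidate_embeddings.T\n",
        "  # Score of the candidate each query actually interacted with.\n",
        "  positive_scores = scores[np.arange(len(positive_ids)), positive_ids]\n",
        "  # A hit means fewer than k candidates strictly outscore the positive.\n",
        "  num_better = (scores > positive_scores[:, None]).sum(axis=1)\n",
        "  return float(np.mean(num_better < k))\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "queries = rng.normal(size=(100, 8))\n",
        "candidates = rng.normal(size=(500, 8))\n",
        "positives = rng.integers(0, 500, size=100)\n",
        "print(top_k_accuracy(queries, candidates, positives, k=100))\n",
        "```\n",
        "\n",
        "Replacing the exact ranking with an approximate top-k search, as the ScaNN-backed metric does, changes which candidates make the cut. This is why the two accuracies above can differ slightly."
      ]
    },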
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-jdPPOlV3JOr"
      },
      "source": [
        "## Deploying the approximate model\n",
        "\n",
        "The `ScaNN`-based model is fully integrated into TensorFlow models, and serving it is as easy as serving any other TensorFlow model.\n",
        "\n",
        "We can save it as a `SavedModel` object:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KnVI_6N53WU5"
      },
      "outputs": [],
      "source": [
        "# We re-index the ScaNN layer to include the user embeddings in the same model.\n",
        "# This way we can give the saved model raw features and get valid predictions\n",
        "# back.\n",
        "scann = tfrs.layers.factorized_top_k.ScaNN(model.user_model, num_reordering_candidates=1000)\n",
        "scann.index_from_dataset(\n",
        "    tf.data.Dataset.zip((lots_of_movies, lots_of_movies_embeddings))\n",
        ")\n",
        "\n",
        "# Need to call it to set the shapes.\n",
        "_ = scann(np.array([\"42\"]))\n",
        "\n",
        "with tempfile.TemporaryDirectory() as tmp:\n",
        "  path = os.path.join(tmp, \"model\")\n",
        "  tf.saved_model.save(\n",
        "      scann,\n",
        "      path,\n",
        "      options=tf.saved_model.SaveOptions(namespace_whitelist=[\"Scann\"])\n",
        "  )\n",
        "\n",
        "  loaded = tf.saved_model.load(path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O5vDZjro4lXG"
      },
      "source": [
        "and then load it and serve, getting exactly the same results back:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TXm8smCt3iFB"
      },
      "outputs": [],
      "source": [
        "_, titles = loaded(tf.constant([\"42\"]))\n",
        "\n",
        "print(f\"Top recommendations: {titles[0][:3]}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S0Doal2ETqU4"
      },
      "source": [
        "The resulting model can be served in any Python service that has TensorFlow and ScaNN installed.\n",
        "\n",
        "It can also be served using a customized version of TensorFlow Serving, available as a Docker container on [Docker Hub](https://hub.docker.com/r/google/tf-serving-scann). You can also build the image yourself from the [Dockerfile](https://github.com/google-research/google-research/tree/master/scann/tf_serving)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0gQsvn5PYbR-"
      },
      "source": [
        "## Tuning ScaNN"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "918uqacB7sNH"
      },
      "source": [
        "Now let's look into tuning our ScaNN layer to get a better performance/accuracy tradeoff. In order to do this effectively, we first need to measure our baseline performance and accuracy.\n",
        "\n",
        "From above, we already have a measurement of our model's latency for processing a single (non-batched) query (although note that a fair amount of this latency is from non-ScaNN components of the model).\n",
        "\n",
        "Now we need to investigate ScaNN's accuracy, which we measure through recall. A recall@k of x% means that, of the true top k neighbors retrieved by brute force, x% also appear among the top k neighbors retrieved by ScaNN. Let's compute the recall for the current ScaNN searcher.\n",
        "\n",
        "First, we need to generate the brute force, ground truth top-k:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qgf_QuP-8EXb"
      },
      "outputs": [],
      "source": [
        "# Process queries in groups of 1000; processing them all at once with brute force\n",
        "# may lead to out-of-memory errors, because processing a batch of q queries against\n",
        "# a size-n dataset takes O(nq) space with brute force.\n",
        "titles_ground_truth = tf.concat([\n",
        "  brute_force_lots(queries, k=10)[1] for queries in\n",
        "  test.batch(1000).map(lambda x: model.user_model(x[\"user_id\"]))\n",
        "], axis=0)"
      ]
    },
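    {
      "cell_type": "markdown",
      "metadata": {
        "id": "scoreMatrixMemory"
      },
      "source": [
        "Why are the queries processed in groups? The estimate below gives the size of the largest intermediate in brute-force retrieval, the dense query-by-candidate score matrix. It assumes float32 scores and the roughly 1.7 million candidates built above (1,682 movies times 1,001 copies). It ignores the extra memory used by the top-k selection, so treat it as an order-of-magnitude guide.\n",
        "\n",
        "```python\n",
        "num_candidates = 1682 * 1001  # size of the enlarged candidate set\n",
        "bytes_per_score = 4           # float32 scores\n",
        "\n",
        "def score_matrix_gb(num_queries):\n",
        "  # Brute force materializes one score per (query, candidate) pair.\n",
        "  return num_queries * num_candidates * bytes_per_score / 1e9\n",
        "\n",
        "print(f'All 20,000 test queries at once: {score_matrix_gb(20_000):.1f} GB')\n",
        "print(f'One group of 1,000 queries:      {score_matrix_gb(1_000):.1f} GB')\n",
        "```\n",
        "\n",
        "Memory grows with the product of queries and candidates, so shrinking the query group is the simplest way to bound it."
      ]
    },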
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LSZkWESc856P"
      },
      "source": [
        "Our variable `titles_ground_truth` now contains the top-10 movie recommendations returned by brute-force retrieval. Now we can compute the same recommendations when using ScaNN:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yUKtdf1X87mP"
      },
      "outputs": [],
      "source": [
        "# Get all user IDs as a 1-D NumPy array of strings.\n",
        "test_flat = np.concatenate(list(test.map(lambda x: x[\"user_id\"]).batch(1000).as_numpy_iterator()), axis=0)\n",
        "\n",
        "# ScaNN is much more memory efficient and has no problem processing the whole\n",
        "# batch of 20000 queries at once.\n",
        "_, titles = scann(test_flat, k=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JTsTDiAZ9F6h"
      },
      "source": [
        "Next, we define our function that computes recall. For each query, it counts how many results are in the intersection of the brute force and the ScaNN results and divides this by the number of brute force results. The average of this quantity over all queries is our recall."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PCtBew2C9Gv0"
      },
      "outputs": [],
      "source": [
        "def compute_recall(ground_truth, approx_results):\n",
        "  return np.mean([\n",
        "      len(np.intersect1d(truth, approx)) / len(truth)\n",
        "      for truth, approx in zip(ground_truth, approx_results)\n",
        "  ])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_tdxlKua9JR2"
      },
      "source": [
        "This gives us baseline recall@10 with the current ScaNN config:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nMi4VtJD9K9P"
      },
      "outputs": [],
      "source": [
        "print(f\"Recall: {compute_recall(titles_ground_truth, titles):.3f}\")"
      ]
    },
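    {
      "cell_type": "markdown",
      "metadata": {
        "id": "recallToyExample"
      },
      "source": [
        "As a sanity check on the definition, here is the logic of `compute_recall` applied to a toy example. The helper is renamed `recall_at_k` so the snippet stands alone. One caveat: `np.intersect1d` returns unique values. When a ground-truth list contains repeated titles (for example, several of the empty-string titles of the dummy movies), recall computed this way cannot reach 1.0 for that query.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def recall_at_k(ground_truth, approx_results):\n",
        "  # Per query: share of the true top-k that the approximate search also returned.\n",
        "  return np.mean([\n",
        "      len(np.intersect1d(truth, approx)) / len(truth)\n",
        "      for truth, approx in zip(ground_truth, approx_results)\n",
        "  ])\n",
        "\n",
        "# Query 1 recovers 2 of its 3 true neighbors, query 2 recovers 1 of 3.\n",
        "ground_truth = [['a', 'b', 'c'], ['d', 'e', 'f']]\n",
        "approx = [['a', 'b', 'x'], ['d', 'y', 'z']]\n",
        "print(recall_at_k(ground_truth, approx))\n",
        "\n",
        "# Duplicates: identical lists, but the repeated '' is only counted once.\n",
        "print(recall_at_k([['', '', 'a']], [['', '', 'a']]))\n",
        "```"
      ]
    },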
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gKpgkNseYWW8"
      },
      "source": [
        "We can also measure the baseline latency:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "81mO-GS4VJLJ"
      },
      "outputs": [],
      "source": [
        "%timeit -n 1000 scann(np.array([\"42\"]), k=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UICnYQln9PAq"
      },
      "source": [
        "Let's see if we can do better!\n",
        "\n",
        "To do this, we need a model of how ScaNN's tuning knobs affect performance. Our current model uses ScaNN's tree-AH algorithm. This algorithm partitions the database of embeddings (the \"tree\") and then scores the most promising of these partitions using AH (asymmetric hashing), a highly optimized approximate distance computation routine.\n",
        "\n",
        "The default parameters for TensorFlow Recommenders' ScaNN Keras layer set `num_leaves=100` and `num_leaves_to_search=10`. This means our database is partitioned into 100 disjoint subsets, and the 10 most promising of these partitions are scored with AH. As a result, about 10/100 = 10% of the dataset is searched with AH.\n",
        "\n",
        "If we have, say, `num_leaves=1000` and `num_leaves_to_search=100`, we would also be searching 10% of the database with AH. However, in comparison to the previous setting, the 10% we would search will contain higher-quality candidates, because a higher `num_leaves` allows us to make finer-grained decisions about what parts of the dataset are worth searching.\n",
        "\n",
        "It's no surprise then that with `num_leaves=1000` and `num_leaves_to_search=100` we get significantly higher recall:"
      ]
    },
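    {
      "cell_type": "markdown",
      "metadata": {
        "id": "treeSketchText"
      },
      "source": [
        "The partition-then-score idea can be sketched in plain NumPy. This is a simplified illustration under several assumptions, not ScaNN's implementation:\n",
        "\n",
        "* The partitions come from a few rounds of k-means.\n",
        "* Candidates in the selected partitions are scored with exact dot products where ScaNN would use AH.\n",
        "* The database is random data standing in for the movie embeddings.\n",
        "\n",
        "`assign`, `build_partitions`, and `tree_search` are hypothetical helpers. `tree_search` also reports what fraction of the database it actually scored. Because k-means partitions are not equally sized, that fraction is only approximately `num_leaves_to_search / num_leaves`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "treeSketchCode"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def assign(data, centroids):\n",
        "  \"\"\"Returns the index of each point's nearest centroid (squared L2).\"\"\"\n",
        "  # ||x||^2 is the same for every centroid, so it can be dropped.\n",
        "  dists = (centroids ** 2).sum(axis=1) - 2.0 * data @ centroids.T\n",
        "  return dists.argmin(axis=1)\n",
        "\n",
        "\n",
        "def build_partitions(data, num_leaves, num_iters=5, seed=0):\n",
        "  \"\"\"Splits `data` into `num_leaves` partitions with a few k-means steps.\"\"\"\n",
        "  rng = np.random.default_rng(seed)\n",
        "  centroids = data[rng.choice(len(data), num_leaves, replace=False)]\n",
        "  for _ in range(num_iters):\n",
        "    labels = assign(data, centroids)\n",
        "    # Move each non-empty partition's centroid to the mean of its points.\n",
        "    sums = np.zeros_like(centroids)\n",
        "    np.add.at(sums, labels, data)\n",
        "    counts = np.bincount(labels, minlength=num_leaves)\n",
        "    nonempty = counts \u003e 0\n",
        "    centroids[nonempty] = sums[nonempty] / counts[nonempty, None]\n",
        "  return centroids, assign(data, centroids)\n",
        "\n",
        "\n",
        "def tree_search(query, data, centroids, labels, num_leaves_to_search, k):\n",
        "  \"\"\"Returns the top-k ids by dot product and the fraction of `data` scored.\"\"\"\n",
        "  # Partitioning: score the query against every centroid and keep the best.\n",
        "  top_leaves = np.argsort(-(centroids @ query))[:num_leaves_to_search]\n",
        "  # Only points in the selected partitions become candidates.\n",
        "  candidates = np.flatnonzero(np.isin(labels, top_leaves))\n",
        "  # Assumption: exact dot products stand in for ScaNN's AH scoring.\n",
        "  scores = data[candidates] @ query\n",
        "  return candidates[np.argsort(-scores)[:k]], len(candidates) / len(data)\n",
        "\n",
        "\n",
        "# Hypothetical stand-ins for the movie embeddings and a user query.\n",
        "rng = np.random.default_rng(42)\n",
        "database = rng.normal(size=(10_000, 32)).astype(np.float32)\n",
        "query = rng.normal(size=32).astype(np.float32)\n",
        "exact = np.argsort(-(database @ query))[:10]\n",
        "\n",
        "for num_leaves, num_leaves_to_search in [(100, 10), (1000, 100)]:\n",
        "  centroids, labels = build_partitions(database, num_leaves)\n",
        "  approx, fraction = tree_search(\n",
        "      query, database, centroids, labels, num_leaves_to_search, k=10)\n",
        "  recall = len(np.intersect1d(exact, approx)) / len(exact)\n",
        "  print(f\"num_leaves={num_leaves}: scored {fraction:.1%} of the database, \"\n",
        "        f\"recall@10={recall:.1f}\")"
      ]
    },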
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vq6L1Qtl9Qan"
      },
      "outputs": [],
      "source": [
        "scann2 = tfrs.layers.factorized_top_k.ScaNN(\n",
        "    model.user_model,\n",
        "    num_leaves=1000,\n",
        "    num_leaves_to_search=100,\n",
        "    num_reordering_candidates=1000)\n",
        "scann2.index_from_dataset(\n",
        "    tf.data.Dataset.zip((lots_of_movies, lots_of_movies_embeddings))\n",
        ")\n",
        "\n",
        "_, titles2 = scann2(test_flat, k=10)\n",
        "\n",
        "print(f\"Recall: {compute_recall(titles_ground_truth, titles2):.3f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G2WR8zPH9TtW"
      },
      "source": [
        "However, as a tradeoff, our latency has also increased. This is because the partitioning step has gotten more expensive: `scann` picks the top 10 of 100 partitions, while `scann2` picks the top 100 of 1000 partitions. The latter can be more expensive because the query must be compared against 10 times as many partition centroids."
      ]
    },
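    {
      "cell_type": "markdown",
      "metadata": {
        "id": "partitionWorkText"
      },
      "source": [
        "A back-of-the-envelope count of the work per query shows where the extra cost comes from. This sketch assumes equally sized partitions and a hypothetical database of 1,000,000 embeddings. It ignores reordering and constant factors, so it is only a rough model of the partitioning and AH stages, not a latency prediction. Under these assumptions both layers score about the same number of points with AH. The difference lies in the 10x larger number of centroids scored during partitioning."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "partitionWorkCode"
      },
      "outputs": [],
      "source": [
        "# Hypothetical database size, for illustration only.\n",
        "NUM_DATABASE = 1_000_000\n",
        "\n",
        "\n",
        "def partition_and_ah_work(num_leaves, num_leaves_to_search):\n",
        "  \"\"\"Rough per-query work, assuming equally sized partitions.\"\"\"\n",
        "  # Partitioning: the query is scored against every partition centroid.\n",
        "  centroids_scored = num_leaves\n",
        "  # AH: every point in the selected partitions is scored.\n",
        "  points_scored_with_ah = NUM_DATABASE * num_leaves_to_search // num_leaves\n",
        "  return centroids_scored, points_scored_with_ah\n",
        "\n",
        "\n",
        "work = {\n",
        "    \"scann\": partition_and_ah_work(num_leaves=100, num_leaves_to_search=10),\n",
        "    \"scann2\": partition_and_ah_work(num_leaves=1000, num_leaves_to_search=100),\n",
        "}\n",
        "for name, (centroids, points) in work.items():\n",
        "  print(f\"{name}: {centroids} centroids scored, ~{points:,} points scored with AH\")"
      ]
    },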
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Po0kb4Mf9VhX"
      },
      "outputs": [],
      "source": [
        "%timeit -n 1000 scann2(np.array([\"42\"]), k=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fCDzY0sc9Zgc"
      },
      "source": [
        "In general, tuning ScaNN search is about picking the right tradeoffs. Each individual parameter change generally won't make search both faster and more accurate; our goal is to tune the parameters to optimally trade off between these two conflicting goals.\n",
        "\n",
        "In our case, `scann2` significantly improved recall over `scann` at some cost in latency. Can we dial back some other knobs to cut down on latency, while preserving most of our recall advantage?\n",
        "\n",
        "Let's try searching 70/1000=7% of the dataset with AH, and only rescoring the final 400 candidates:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jBp8Yvdj9pMQ"
      },
      "outputs": [],
      "source": [
        "scann3 = tfrs.layers.factorized_top_k.ScaNN(\n",
        "    model.user_model,\n",
        "    num_leaves=1000,\n",
        "    num_leaves_to_search=70,\n",
        "    num_reordering_candidates=400)\n",
        "scann3.index_from_dataset(\n",
        "    tf.data.Dataset.zip((lots_of_movies, lots_of_movies_embeddings))\n",
        ")\n",
        "\n",
        "_, titles3 = scann3(test_flat, k=10)\n",
        "print(f\"Recall: {compute_recall(titles_ground_truth, titles3):.3f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3Isgpm7b9rgE"
      },
      "source": [
        "`scann3` achieves about a 3% absolute recall gain over `scann` while also delivering lower latency:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JiDEWwtr9sKG"
      },
      "outputs": [],
      "source": [
        "%timeit -n 1000 scann3(np.array([\"42\"]), k=10)"
      ]
    },
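    {
      "cell_type": "markdown",
      "metadata": {
        "id": "stageWorkText"
      },
      "source": [
        "The same kind of rough work count (equally sized partitions, a hypothetical database of 1,000,000 embeddings) can be used to compare `scann3` with `scann2`. Both use `num_leaves=1000`, so the centroid-scoring cost is unchanged. The savings come from scoring fewer points with AH (70 instead of 100 partitions) and from exactly rescoring fewer candidates (400 instead of 1000). `stage_work` is a hypothetical helper, and its output is a model of relative work, not a latency measurement."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "stageWorkCode"
      },
      "outputs": [],
      "source": [
        "# Hypothetical database size, for illustration only.\n",
        "NUM_DATABASE = 1_000_000\n",
        "\n",
        "\n",
        "def stage_work(num_leaves, num_leaves_to_search, num_reordering_candidates):\n",
        "  \"\"\"Rough per-query work for the AH and reordering stages.\"\"\"\n",
        "  return {\n",
        "      # Points in the selected partitions, assuming equally sized partitions.\n",
        "      \"ah_scored\": NUM_DATABASE * num_leaves_to_search // num_leaves,\n",
        "      # Candidates rescored exactly in the reordering stage.\n",
        "      \"rescored\": num_reordering_candidates,\n",
        "  }\n",
        "\n",
        "\n",
        "scann2_work = stage_work(1000, 100, 1000)\n",
        "scann3_work = stage_work(1000, 70, 400)\n",
        "for stage in scann2_work:\n",
        "  ratio = scann3_work[stage] / scann2_work[stage]\n",
        "  print(f\"{stage}: scann3 does {ratio:.0%} of scann2's work\")"
      ]
    },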
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NwWKyQgt9uh1"
      },
      "source": [
        "These knobs can be further adjusted to optimize for different points along the accuracy-performance Pareto frontier. ScaNN's algorithms can achieve state-of-the-art performance over a wide range of recall targets."
      ]
    },
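    {
      "cell_type": "markdown",
      "metadata": {
        "id": "paretoText"
      },
      "source": [
        "When comparing several configurations, the ones worth keeping are those on the Pareto frontier. A configuration is on the frontier if no other configuration has both higher-or-equal recall and lower-or-equal latency, with at least one of the two strictly better. A minimal sketch of that filter follows. `pareto_frontier` is a hypothetical helper, and the recall and latency values in the example are made-up placeholders rather than measurements from this notebook. In practice you would fill them in with the recall and `%timeit` results measured above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "paretoCode"
      },
      "outputs": [],
      "source": [
        "def pareto_frontier(configs):\n",
        "  \"\"\"Returns the (name, recall, latency) tuples not dominated by any other.\"\"\"\n",
        "  frontier = []\n",
        "  for name, recall, latency in configs:\n",
        "    # Dominated: another config is at least as good on both axes and\n",
        "    # strictly better on one (higher recall or lower latency).\n",
        "    dominated = any(\n",
        "        other_recall \u003e= recall and other_latency \u003c= latency and\n",
        "        (other_recall \u003e recall or other_latency \u003c latency)\n",
        "        for _, other_recall, other_latency in configs)\n",
        "    if not dominated:\n",
        "      frontier.append((name, recall, latency))\n",
        "  return frontier\n",
        "\n",
        "\n",
        "# Made-up placeholder numbers (recall, latency in ms), not real measurements.\n",
        "candidates = [\n",
        "    (\"config_a\", 0.90, 1.0),\n",
        "    (\"config_b\", 0.95, 2.0),\n",
        "    (\"config_c\", 0.85, 1.5),\n",
        "    (\"config_d\", 0.95, 2.5),\n",
        "]\n",
        "for name, recall, latency in pareto_frontier(candidates):\n",
        "  print(f\"{name}: recall={recall:.2f}, latency={latency:.1f} ms\")"
      ]
    },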
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UvlCsKyFU40k"
      },
      "source": [
        "## Further reading"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0ikGqmNa9yRG"
      },
      "source": [
        "ScaNN uses advanced vector quantization techniques and a highly optimized implementation to achieve its results. The field of vector quantization has a rich history with a variety of approaches. ScaNN's current quantization technique is detailed in [this paper](https://arxiv.org/abs/1908.10396), published at ICML 2020. The paper was released along with [this blog article](https://ai.googleblog.com/2020/07/announcing-scann-efficient-vector.html), which gives a high-level overview of our technique.\n",
        "\n",
        "Many related quantization techniques are mentioned in the references of our ICML 2020 paper, and other ScaNN-related research is listed at http://sanjivk.com/."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "efficient_serving.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
